
import csv
from numpy import *
import random
from copy import deepcopy
import sys
import os
from heapq import *

# Pickle behaves differently under Py2 and Py3
if sys.version[0] == "2": 
    import cPickle
def pickle_exists(fname):
    """Tell whether a usable pickle cache file named *fname* exists.

    The pickle cache is only used under Python 2; on Python 3 this always
    answers False so callers rebuild their data instead of loading it.
    """
    running_py3 = sys.version[0] == "3"
    return (not running_py3) and os.path.exists(fname)

def pickle_load(fname):
    """Load and return the object pickled in the file *fname* (Python 2 only).

    Raises:
        Exception: always on Python 3, where the pickle cache is disabled.
    """
    if sys.version[0] == "3":
        raise Exception("pickle cache is disabled on Python 3")
    # NOTE: unpickling can execute arbitrary code; only load cache files
    # that this script wrote itself.
    # The context manager closes the file handle even if loading fails.
    with open(fname, "rb") as pkl_file:
        return cPickle.load(pkl_file)
    
def pickle_dump(obj, fname):
    """Pickle *obj* into the file *fname* (Python 2 only).

    On Python 3 the pickle cache is disabled, so this does nothing.
    Returns None in both cases.
    """
    if sys.version[0] == "3":
        return
    # The context manager guarantees the data is flushed and the handle
    # closed before the function returns.
    with open(fname, "wb") as pkl_file:
        cPickle.dump(obj, pkl_file)
    

##########################################
# Some constants for running this script
##########################################

# Corrupt data - how to reconcile a student who has two different scores
# for the funded and non-funded contract of the same program
MSS_REALIZED_THEN_FUNDED = "realized then funded"
MSS_REALIZED_OR_MINIMUM = "realized or minimum"
MSS_FUNDED = "funded"
MULTIPLE_SCORES_SELECT = MSS_REALIZED_OR_MINIMUM

# Corrupt data - paradoxical records handling
REMOVE_ABC = True
REMOVE_ABC_CAPS = True
# Suffix made of the two flags above rendered as 0/1 digits
ABC_REMOVE_STRING = "abc_remove_%d%d" % (REMOVE_ABC, REMOVE_ABC_CAPS)

# Tie-breaking seed
GLOBAL_SEED = 12345

# Ignoring unacceptable candidates
# The threshold is moved to just below the next integer so that it still
# works after tie-breaking has added a fractional part to the scores
UNACCEPTABLE_SCORE = -2
UNACCEPTABLE_SCORE = floor(UNACCEPTABLE_SCORE) + 1 - 10**(-10)

# Method of rejects and flips removal in the alternate algorithms
REMOVE_HIGHEST_ONE = "remove highest one"
REMOVE_HIGHEST_EACH_CONTRACT = "remove highest each contract"
REMOVE_ALL = "remove all"
REMOVE_MODE = REMOVE_HIGHEST_EACH_CONTRACT

# When set, dual programs neither reject students nor get flipped by
# students in the alternate algorithms.
AVOID_DUAL_PROGRAMS = True



############################
# Some helping functions
############################

def union(*args):
    """Return a new set containing every element of all given iterables."""
    return set().union(*args)


###############################################
# Classes for students, contracts and programs
###############################################
class Student:
    """An applicant, identified by ``ID``."""

    def __init__(self,sid):
        self.ID = sid
        # ranking holds contract IDs in preference order; full_ranking
        # holds the [rank, contract ID] pairs it is derived from.
        self.ranking, self.full_ranking = [], []
        # Contract ID the student was actually admitted to (None if none)
        self.real_admitted = None
        self.grossincome = None

class Contract:
    """One contract (a program under a specific funding type)."""

    def __init__(self, cid):
        self.ID = cid
        self.pid = None                 # ID of the owning program
        self.funded = None
        self.ranking = []               # [score, student ID] pairs
        self.scores_dic = {}            # student ID -> score
        self.capacity = 10 ** 8         # effectively unlimited until set
        self.cutoff = None

        # Two admission counts: one read from the data, one counted
        self.real_admitted1 = None
        self.real_admitted2 = 0

class Program:
    """A program: a group of contracts that share a program ID.

    Attributes:
        ID: the program ID.
        cs: private copy of the list of contracts given at construction.
        f:  the contract whose ``funded`` equals 1, or None.
        nf: the contract whose ``funded`` equals 0, or None.
    """
    def __init__(self, pid, cs = None):
        self.ID = pid
        # None instead of a shared mutable default; any iterable is copied
        # so later changes to the caller's list do not affect this program.
        self.cs = [] if cs is None else list(cs)
        self.f = None
        self.nf = None

        # If several contracts have the same funding flag, the last wins.
        for c in self.cs:
            if c.funded == 1:
                self.f = c
            elif c.funded == 0:
                self.nf = c


############################
# Read data
############################

# Load the 2007 application data: reuse the pickle cache when available,
# otherwise parse the CSV and build the students / contracts / programs
# dictionaries (each keyed by ID).
print("Reading 2007 data...")

data_fname = "data_2007.pkl"
if pickle_exists(data_fname):
    students, contracts, programs = pickle_load(data_fname)

else:
    applicants_fname = "Applications.csv"
    with open(applicants_fname) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        line_count = 0
        data = []
        for row in csv_reader:
            if line_count == 0:
                # Header row: whitespace inside a column title becomes "_"
                # so that the title can serve as a variable name below.
                var_names = ["_".join(x.split()) for x in row]
                line_count += 1
            else:
                data.append(row)
                line_count += 1
    print("Processed %d lines." % line_count)

    n_records = len(data)

    # Create one module-level list per CSV column, named after its header
    # (the code below reads e.g. id, rank, contract_id, admitted, funding).
    # NOTE: running exec on header text is only acceptable because the CSV
    # is a trusted local file; never point this at untrusted input.
    for i in range(len(var_names)):
        exec("%s = [x[%d] for x in data]" % (var_names[i],i))

    # Create students
    student_ids = unique(id)
    students = {}
    for sid in student_ids:
        students[sid] = Student(sid)

    for i in range(n_records):

        sid = id[i]
        s = students[sid]
        s.full_ranking.append([int(rank[i]),contract_id[i]])

        if admitted[i] == "1":
            s.real_admitted = contract_id[i]

        int_fields = ["disadv1", "disadv2"]
        float_fields = ["grossincome", "GPA11_3", "ses"]

        # Best-effort parsing: a field stays None when its value in this
        # record cannot be converted.
        for field in int_fields + float_fields:
            exec("s.%s = None" % field)
        for field in int_fields:
            try:
                exec("s.%s = int(%s[i])" % (field,field))
            except: pass
        for field in float_fields:
            try:
                exec("s.%s = float(%s[i])" % (field,field))
            except: pass

    # Order every student's applications by rank and keep the contract IDs
    for sid in students:
        students[sid].full_ranking.sort()
        students[sid].ranking = [x[1] for x in students[sid].full_ranking]

    # Create contracts
    contract_ids = unique(contract_id)
    contracts = {}
    for cid in contract_ids:
        contracts[cid] = Contract(cid)

    for i in range(n_records):
        sid = id[i]
        cid = contract_id[i]
        contracts[cid].pid = program_id[i]
        contracts[cid].funded = [0,1]["State" in funding[i]]
        score = int(priorityscore[i])
        contracts[cid].scores_dic[sid] = score
        contracts[cid].ranking.append([score,sid])

        int_fields = ["no_capacity_in_booklet", "no_official_capacity", \
                      "capacity_official", "capacity_min", "capacity_total", \
                      "dual"]
        str_fields = ["karkod", "szaknev", "program_level", "munkarend", \
                      "funding"]
        float_fields = ["capacity_2006", "capacity_2006_proportional"]

        # Same best-effort parsing as for students: unparsable -> None
        for field in int_fields + str_fields + float_fields:
            exec("contracts[cid].%s = None" % field)
        for field in int_fields:
            try:
                exec("contracts[cid].%s = int(%s[i])" % (field, field))
            except: pass
        for field in str_fields:
            try:
                exec("contracts[cid].%s = %s[i]" % (field, field))
            except: pass

        for field in float_fields:
            try:
                exec("contracts[cid].%s = float(%s[i])" % (field, field))
            except: pass


        try:
            contracts[cid].cutoff = int(priorityscore_cutoff[i])
        except:
            pass

        try:
            contracts[cid].real_admitted1 = int(admitted_applicants[i])
        except:
            pass

        if admitted[i] == "1":
            contracts[cid].real_admitted2 += 1

    # Sort rankings (highest score first)
    for c in contracts.values():
        c.ranking = sorted(c.ranking)[::-1]

    # Default: the capacity used in the algorithms is the official capacity
    for c in contracts.values():
        c.capacity = c.capacity_official

    # Create programs
    programs = {}
    program_ids = {c.pid for c in contracts.values()}
    for pid in program_ids:
        programs[pid] = Program(pid, [c for c in contracts.values() if c.pid == pid])


    # Override the capacities of programs that have one or two contracts,
    # based on the realized admission counts:
    for p in programs.values():
        if len(p.cs) == 1:
            # Test this program's own contract. The variable `c` left over
            # from the loops above refers to an unrelated contract.
            if p.cs[0].funded:
                p.cs[0].capacity = p.cs[0].real_admitted2
            else:
                p.cs[0].capacity = max(p.cs[0].real_admitted2, \
                                       p.cs[0].capacity_official)
        elif len(p.cs) == 2:
            if p.f.real_admitted2 + p.nf.real_admitted2 > p.f.capacity_official:
                p.f.capacity = p.f.real_admitted2
                p.nf.capacity = p.nf.real_admitted2
            else:
                p.f.capacity = p.f.real_admitted2
                p.nf.capacity = p.f.capacity_official - p.f.real_admitted2

    # Save these capacities under c.fixed_real_capacity
    for c in contracts.values():
        c.fixed_real_capacity = c.capacity


    # Save everything to file
    pickle_dump([students,contracts,programs], data_fname)

print("Done.\n")

#########################################################
# Fix corrupted data: multiple scores for same program
#########################################################

# Number of (student, program) score conflicts that were reconciled
n_MSS_fixed_scores = 0

# Walk over the funded, non-dual contracts whose program has exactly two
# contracts, and compare each student's score there with the score the same
# student has on the program's non-funded contract.
for cid in contracts:
    c = contracts[cid]
    if not c.funded: continue
    if c.dual: continue
    c_f = c
    if len(programs[c_f.pid].cs) != 2: continue
    c_nf = programs[c_f.pid].nf
    for sid in c_f.scores_dic:
        # Only students who applied to both contracts and whose two
        # scores differ need fixing
        if sid not in c_nf.scores_dic: continue
        if c_f.scores_dic[sid] == c_nf.scores_dic[sid]: continue
        n_MSS_fixed_scores += 1
        if MULTIPLE_SCORES_SELECT == MSS_REALIZED_THEN_FUNDED:
            # Use the score of the contract the student was really admitted
            # to if that is the non-funded one, otherwise the funded score
            s = students[sid]
            if s.real_admitted == c_nf.ID:
                c_f.scores_dic[sid] = c_nf.scores_dic[sid]
            else:
                c_nf.scores_dic[sid] = c_f.scores_dic[sid]

        elif MULTIPLE_SCORES_SELECT == MSS_REALIZED_OR_MINIMUM:
            # Use the score of the contract the student was really admitted
            # to; if admitted to neither, use the lower of the two scores
            s = students[sid]
            if s.real_admitted == c_nf.ID:
                c_f.scores_dic[sid] = c_nf.scores_dic[sid]
            elif s.real_admitted == c_f.ID:
                c_nf.scores_dic[sid] = c_f.scores_dic[sid]
            else:
                m = min(c_nf.scores_dic[sid], c_f.scores_dic[sid])
                c_nf.scores_dic[sid] = m
                c_f.scores_dic[sid] = m
            
        elif MULTIPLE_SCORES_SELECT == MSS_FUNDED:
            # Always use the funded contract's score
            c_nf.scores_dic[sid] = c_f.scores_dic[sid]

# NOTE(review): only scores_dic is updated here; the contracts' ranking
# lists keep the old scores - confirm they are rebuilt before being used.
print("INFO: Fixed %d scores of students with multiple scores in the same program" % \
      n_MSS_fixed_scores)
    

#########################################################
# Fix corrupted data: Paradoxical students
# * A are people that have zero priorities where they
# were admitted
# * B are people who have lower score than the cutoff
# where they were admitted
# * C are people that have justified envy
#########################################################

print("Trying to identify corrupted data...")

# Sets of student IDs, one per kind of paradox. D is the subset of C whose
# score at a contract they could have claimed is below 78.
# NOTE(review): the meaning of the 78 threshold is not stated in this file.
A = set()
B = set()
C = set()
D = set()
    
for sid in students:
    s = students[sid]
    cid = s.real_admitted
    if cid != None: 
        c = contracts[cid]
        # A: admitted with a non-positive score; B: admitted below cutoff
        if c.scores_dic[sid] <= 0:
            A.add(sid)
        elif c.scores_dic[sid] < c.cutoff:
            B.add(sid)

    # Inspect only the contracts ranked above the one the student was
    # admitted to (the whole ranking if the student was not admitted)
    check_len = len(s.ranking)
    if cid != None:
        check_len = s.ranking.index(cid)
    
    for cid in s.ranking[:check_len]:
        c = contracts[cid]
        # C: the student reaches the cutoff of a contract they prefer
        if c.cutoff != None and c.scores_dic[sid] >= c.cutoff:
            C.add(sid)
            if c.scores_dic[sid] < 78:
                D.add(sid)

print("ABC sets created.")
print("len(A) = %d" % len(A))
print("len(B) = %d" % len(B))
print("len(C) = %d" % len(C))


#########################################################
# Run original algorithm (without tie-breaking)
#########################################################

def cutoffs_alg_no_tb(students, contracts, use_real_admitted = set()):
    """Run the cutoff-based admission algorithm without tie-breaking.

    Students in *use_real_admitted* do not take part: each of them keeps
    the contract in ``real_admitted`` and uses up one of its seats.
    Everybody else applies down their ranking; a contract that ends up
    over capacity raises its cutoff to one above its lowest held score,
    which releases every student holding that lowest score.

    Returns a dict mapping each participating student ID to the first
    contract in their ranking whose final cutoff they reach, or None.
    """

    # Seats left after reserving those of the pre-assigned students
    fixed_caps = {cid : contract.capacity for cid, contract in contracts.items()}
    for fixed_sid in use_real_admitted:
        reserved_cid = students[fixed_sid].real_admitted
        if reserved_cid != None:
            fixed_caps[reserved_cid] -= 1

    temp_mu = {cid : [] for cid in contracts}
    cutoffs = {cid : 0 for cid in contracts}

    proposers = set(students.keys()).difference(use_real_admitted)

    while proposers:
        sid = proposers.pop()

        # Apply to the first contract whose cutoff the student reaches
        for cid in students[sid].ranking:
            if contracts[cid].scores_dic[sid] >= cutoffs[cid]:
                temp_mu[cid].append(sid)
                break
        else:
            # No reachable contract is left for this student
            continue

        if len(temp_mu[cid]) <= fixed_caps[cid]:
            continue

        # Over capacity: raise the cutoff just above the lowest held score
        scores = contracts[cid].scores_dic
        cutoffs[cid] = min([scores[held] for held in temp_mu[cid]]) + 1

        kept = []
        for held in temp_mu[cid]:
            if scores[held] < cutoffs[cid]:
                proposers.add(held)
            else:
                kept.append(held)
        temp_mu[cid] = kept

    # The final cutoffs determine the matching
    matching = {}
    for sid in students:
        if sid in use_real_admitted:
            continue
        matching[sid] = next((cid for cid in students[sid].ranking
                              if contracts[cid].scores_dic[sid] >= cutoffs[cid]),
                             None)

    return matching

#########################################################
# Fix corrupted data: Paradoxical students (continued)
#########################################################

# All paradoxical students, and for every contract the number of them that
# were really admitted to it
ABC_sids = union(A,B,C)
ABC_caps = {}
for sid in ABC_sids:
    s = students[sid]
    cid = s.real_admitted
    if cid == None:
        continue
    ABC_caps.setdefault(cid, 0)
    ABC_caps[cid] += 1

def remove_ABC():
    """Drop the corrupted students (sets A, B and C) from the data.

    Each such student is deleted from the global ``students`` dict and
    from the ``scores_dic`` of every contract. Capacities are untouched.
    """

    global students, contracts, ABC_sids
    for sid in union(A,B,C):
        del students[sid]
        for contract in contracts.values():
            contract.scores_dic.pop(sid, None)

def remove_ABC_caps():
    """Reduce capacities by the seats that corrupted students occupied.

    For every contract in the global ``ABC_caps`` the capacity-related
    fields are lowered by the recorded count; ``capacity_min`` and
    ``capacity_2006`` are floored at 0. The other contracts of the same
    program get their official and minimum capacities lowered as well.
    """

    global students, contracts, ABC_caps

    for cid, removed in ABC_caps.items():
        contract = contracts[cid]
        contract.capacity -= removed
        contract.fixed_real_capacity -= removed
        contract.capacity_official -= removed
        contract.capacity_min = max(contract.capacity_min - removed, 0)
        contract.real_admitted2 -= removed
        contract.capacity_2006 = max(contract.capacity_2006 - removed, 0)

        # Sibling contracts of the same program
        for sibling in programs[contract.pid].cs:
            if sibling.ID != cid:
                sibling.capacity_official -= removed
                sibling.capacity_min = max(sibling.capacity_min - removed, 0)
            


#############################################################
# Tie-breaking functions
#############################################################

def STB(students, contracts, base_dic = None, seed = None, funded_var = None):
    """Run single tie-breaking.

    Every student draws one random number that is added to their score at
    every contract. The contracts' ``scores_dic`` and ``ranking`` are
    replaced in place; nothing is returned.

    Args:
        students: dict of student ID -> student.
        contracts: dict of contract ID -> contract.
        base_dic: optional dict contract ID -> {student ID: score}. When
            given, scores are rebuilt from it instead of from the current
            ``scores_dic`` (so repeated calls do not accumulate).
        seed: optional seed for the ``random`` module.
        funded_var: optional name of a student attribute used to break
            ties at funded contracts (higher value first; prefix the name
            with "-" for lower value first). Students whose value is None
            are treated as having the mean value. Non-funded contracts
            always use the purely random tie-breaker.
    """

    if seed is not None:
        random.seed(seed)

    rand_tb = {sid : random.random() for sid in students}

    funded_tb = rand_tb
    non_funded_tb = rand_tb

    if funded_var is not None:
        sign = 1
        fv = funded_var
        if funded_var[0] == "-":
            fv = funded_var[1:]
            sign = -1

        # Signed value of the variable per student (None = missing)
        signed = {}
        for s in students.values():
            raw = getattr(s, fv)
            signed[s.ID] = None if raw is None else sign * raw

        av = array([v for v in signed.values() if v is not None])
        distinct = unique(av)    # sorted distinct values

        # With fewer than two distinct values there is nothing to order
        # by, so the purely random tie-breaker stays in place.
        if len(distinct) >= 2:
            mean_av = mean(av)
            min_av = distinct[0]
            diff_av = float(distinct[-1] - min_av)
            min_diff = min(distinct[1:] - distinct[:-1])

            # NOTE(review): 0.25/min_diff exceeds 1 when min_diff < 0.25,
            # letting the random part spill into the integer score;
            # 0.25*min_diff/diff_av may have been intended - confirm.
            rand_scale = 0.25 / min_diff

            # The tie-breaker depends on the student only; it is added to
            # each contract's own score further below.
            funded_tb = {}
            for s in students.values():
                value = signed[s.ID]
                if value is None:
                    value = mean_av
                funded_tb[s.ID] = 0.5 * (value - min_av) / diff_av \
                                  + rand_scale * rand_tb[s.ID]

    for c in contracts.values():
        base = c.scores_dic if base_dic is None else base_dic[c.ID]
        tb = funded_tb if c.funded else non_funded_tb
        c.scores_dic = {sid : base[sid] + tb[sid] for sid in base}
        # Highest score first
        c.ranking = sorted([[score, sid] for sid, score in c.scores_dic.items()])[::-1]

    
            
def MTB(students, contracts, base_dic = None, seed = None, funded_var = None):
    """Run multiple tie-breaking.

    Every (contract, student) pair gets its own random draw added to the
    score. Scores are taken from ``base_dic[contract ID]`` when *base_dic*
    is given, otherwise from the contract's current ``scores_dic``. Each
    contract's ``scores_dic`` and ``ranking`` are replaced in place.
    *students* and *funded_var* are accepted but not used.
    """

    if seed != None:
        random.seed(seed)

    for c in contracts.values():
        source = c.scores_dic if base_dic == None else base_dic[c.ID]

        perturbed = {}
        for sid in source:
            perturbed[sid] = source[sid] + random.random()

        c.scores_dic = perturbed
        # Highest score first
        c.ranking = sorted([[score, sid] for sid, score in perturbed.items()],
                           reverse=True)


#############################################################
# SPDA - optimized version
#############################################################

# This function assumes tie-breaking has been applied
# One optimization is saving the current state as a variable you
# can continue from. This allows us to run most of the proposers, but not
# all of them, and then continue from there on.

class MatchingState:
    """Snapshot of a (possibly unfinished) SPDA run.

    Attributes:
        temp_mu:   contract ID -> list of tentatively held entries.
        cutoffs:   contract ID -> current cutoff (starts at -1000).
        proposers: set of (student ID, student) pairs still to propose;
                   initially every student with a non-empty ranking.
        curr_inds: student ID -> position in the student's ranking of
                   the next contract to try (starts at 0).
    """

    def __init__(self, students, contracts):
        self.temp_mu = {cid : [] for cid in contracts}
        self.cutoffs = {cid : -1000 for cid in contracts}

        self.proposers = {(sid, students[sid]) for sid in students
                          if len(students[sid].ranking) > 0}
        self.curr_inds = {entry[0] : 0 for entry in self.proposers}

    def copy(self):
        """Return an independent copy (the students themselves are shared)."""
        clone = MatchingState({}, {})
        clone.temp_mu = deepcopy(self.temp_mu)
        clone.cutoffs = self.cutoffs.copy()
        clone.proposers = self.proposers.copy()
        clone.curr_inds = self.curr_inds.copy()
        return clone


def spda_opt(students, contracts, matching_state = None, return_ms = False):
    """Run Student-Proposing Deferred Acceptance.

    Assumes tie-breaking has already been applied to the scores.

    Args:
        students: dict of student ID -> student.
        contracts: dict of contract ID -> contract.
        matching_state: optional MatchingState to continue from; its
            contents are copied before being worked on.
        return_ms: when true, store the final state in the MatchingState
            object and return it instead of a matching.

    Returns:
        The MatchingState if *return_ms* is true, otherwise a dict that
        maps every student ID to the matched contract ID (or None).
    """

    if matching_state == None:
        ms = MatchingState(students, contracts)
        temp_mu, cutoffs = ms.temp_mu, ms.cutoffs
        proposers, curr_inds = ms.proposers, ms.curr_inds
    else:
        ms = matching_state
        temp_mu = deepcopy(ms.temp_mu)
        cutoffs = ms.cutoffs.copy()
        proposers = ms.proposers.copy()
        curr_inds = ms.curr_inds.copy()

    while proposers:

        sid, s = proposers.pop()
        cid = s.ranking[curr_inds[sid]]

        # Move down the ranking until a contract is found where the
        # student is acceptable and reaches the current cutoff
        while True:
            score = contracts[cid].scores_dic[sid]
            if not (score < cutoffs[cid] or score < UNACCEPTABLE_SCORE):
                break
            curr_inds[sid] += 1
            if curr_inds[sid] == len(s.ranking):
                cid = None
                break
            cid = s.ranking[curr_inds[sid]]

        if cid == None:
            # Ranking exhausted: the student stays unmatched
            continue

        # Each contract keeps a min-heap of (score, student ID)
        held = temp_mu[cid]
        heappush(held, (score, sid))

        if len(held) > contracts[cid].capacity:
            # Over capacity: release the lowest-scoring student
            low_score, sid_out = heappop(held)
            curr_inds[sid_out] += 1
            if curr_inds[sid_out] < len(students[sid_out].ranking):
                proposers.add((sid_out, students[sid_out]))

            # The cutoff now lies just above the released score
            cutoffs[cid] = low_score + 10**(-10)

    if return_ms:
        ms.temp_mu = temp_mu
        ms.cutoffs = cutoffs
        ms.proposers = proposers
        ms.curr_inds = curr_inds
        return ms

    matching = {sid : None for sid in students}
    for cid, held in temp_mu.items():
        for entry in held:
            matching[entry[1]] = cid

    return matching

#############################################################
# SRDA
#############################################################

def srda(students, contracts):
    """Run Student-Receiving Deferred Acceptance.

    Contracts with free seats make offers down their ``ranking``; a
    student keeps the offer from the contract ranked best in their own
    ranking and releases the previous one.

    Returns a dict mapping every student ID to a contract ID or None.
    """

    # For each student: position in their ranking of the offer they hold
    temp_mu = {sid : None for sid in students}
    left_cap = {c.ID : c.capacity for c in contracts.values()}
    proposers = {c.ID for c in contracts.values()
                 if len(c.ranking) > 0 and c.capacity > 0}
    curr_inds = {cid : 0 for cid in contracts}

    while proposers:
        cid = proposers.pop()
        c = contracts[cid]

        # Next candidate in the contract's ranking ([score, student ID])
        entry = c.ranking[curr_inds[cid]]
        score, sid = entry[0], entry[1]
        curr_inds[cid] += 1

        if score > UNACCEPTABLE_SCORE and sid in students:
            prefs = students[sid].ranking

            if cid in prefs:
                offer_rank = prefs.index(cid)
                held_rank = temp_mu[sid]
                if held_rank == None:
                    # Student holds nothing yet: accept
                    temp_mu[sid] = offer_rank
                    left_cap[cid] -= 1
                elif held_rank > offer_rank:
                    # Better offer: switch and free a seat at the old one
                    released_cid = prefs[held_rank]
                    temp_mu[sid] = offer_rank
                    left_cap[cid] -= 1
                    left_cap[released_cid] += 1
                    if curr_inds[released_cid] < len(contracts[released_cid].ranking):
                        proposers.add(released_cid)

        # Keep offering while seats and candidates remain
        if left_cap[cid] > 0 and curr_inds[cid] < len(c.ranking):
            proposers.add(cid)

    matching = {sid : None for sid in students}
    for sid, held_rank in temp_mu.items():
        if held_rank != None:
            matching[sid] = students[sid].ranking[held_rank]

    return matching



#############################################################
# Get list of potential rejects
# (people on which programs can apply local market power)
#############################################################

# This function assumes tie-breaking was applied!
def get_start_cutoffs_no_group(students, contracts, no_group = set()):
    """Helper function for optimizing search for potential rejections.

    Runs SPDA without the students in *no_group* and derives a cutoff per
    contract from the result. Assumes tie-breaking was applied.

    Args:
        students: dict of student ID -> student.
        contracts: dict of contract ID -> contract.
        no_group: student IDs to leave out of the market.

    Returns:
        dict contract ID -> cutoff: 10**8 for contracts with capacity 0,
        0 for contracts that are not full, otherwise the lowest score
        among the students matched to the contract.

    Raises:
        KeyError: if an ID in *no_group* is not in *students*.
    """

    # A shallow copy is enough: only the membership of the dict changes,
    # and spda_opt reads the student objects without modifying them.
    students_no_rel = dict(students)
    for sid in no_group:
        del students_no_rel[sid]
    mu_no_relevant = spda_opt(students_no_rel, contracts)

    # Group the matched students by contract (None = unmatched)
    contract_to_match_no_rel = {None : set()}
    for cid in contracts:
        contract_to_match_no_rel[cid] = set()
    for sid, cid in mu_no_relevant.items():
        contract_to_match_no_rel[cid].add(sid)

    start_cutoffs = {}
    for c in contracts.values():
        matched = contract_to_match_no_rel[c.ID]
        if c.capacity == 0:
            start_cutoffs[c.ID] = 10**8
        elif len(matched) < c.capacity:
            start_cutoffs[c.ID] = 0
        else:
            start_cutoffs[c.ID] = min([c.scores_dic[sid] for sid in matched])

    return start_cutoffs

# This function starts by applying tie-breaking!
def get_potential_rejects(create_full, seed, fname_add="regular"):
    """Get a list of contract-student pairs, representing potential rejects.

    Works on the module-level ``students``, ``contracts``, ``programs`` and
    ``backup_scores_dic`` (the latter is defined elsewhere in the file).
    Single tie-breaking with *seed* is applied first. A candidate pair is
    (funded contract ID, student ID) where SPDA matches the student to the
    funded contract and the student also ranks the program's non-funded
    contract. Progress is cached in a pickle file whose name is built from
    *fname_add* and *seed*.

    Args:
        create_full: when true, return all candidate pairs untested.
        seed: tie-breaking seed.
        fname_add: label used in the cache file name.

    Returns:
        A list of (contract ID, student ID) tuples: all candidates if
        *create_full* is true, otherwise only those found rejectable.
    """

    STB(students, contracts, backup_scores_dic, seed)

    mu = spda_opt(students, contracts)

    # Create contract to match
    contract_to_match = {}
    contract_to_match[None] = set()
    for cid in contracts:
        contract_to_match[cid] = set()
    for sid in students:
        cid = mu[sid]
        contract_to_match[cid].add(sid)

    # Create full market cutoffs
    full_market_cutoffs = get_start_cutoffs_no_group(students, contracts, set())


    # relevant_pairs maps (cid, sid) -> [tested flag, rejectable flag]
    pairs_fname = "potential_rejects_%s_%s.pkl" % (fname_add,str(seed))
    if pickle_exists(pairs_fname):
        relevant_pairs = pickle_load(pairs_fname)
    else:
        relevant_pairs = {}
        for sid in students:
            s = students[sid]
            cid = mu[sid]
            if cid == None: continue
            c = contracts[cid]
            if c.funded:
                c_nf = programs[c.pid].nf
                if c_nf != None and c_nf.ID in s.ranking:
                    relevant_pairs[cid,sid] = [0,0]

        pickle_dump(relevant_pairs, pairs_fname)

    # NOTE(review): on this early return the tie-broken scores are left in
    # place (they are only reset at the end of the function) - confirm
    # that callers expect this.
    if create_full:
        return [pair for pair in relevant_pairs]

    # First pass: Go over all pairs, and check if it is highly unlikely
    # to get the rejected person because they have programs in between
    # the two contracts where they pass the current threshold
    print("Trying to rule out highly unlikely pairs.")
    for pair in relevant_pairs:
        if relevant_pairs[pair][0] == 1: continue
        cid_f = pair[0]
        c_f = contracts[cid_f]
        sid_reject = pair[1]
        cid_nf = programs[c_f.pid].nf.ID
        s = students[sid_reject]
        start_ind = s.ranking.index(cid_f) + 1
        end_ind = s.ranking.index(cid_nf)
        for other_cid in s.ranking[start_ind:end_ind]:
            other_c = contracts[other_cid]
            if other_c.scores_dic[sid_reject] > full_market_cutoffs[other_cid]:
                # Mark as tested (and not rejectable)
                relevant_pairs[pair][0] = 1
                break


    # Go over pairs and try to see what happens if we drop the rich person
    # from the funded contract
    for pair in relevant_pairs:
        if relevant_pairs[pair][0] == 1: continue

        print("%d / %d tested.\t%d rejectable found." % (sum([x[0] for x in relevant_pairs.values()]), \
                                                         len(relevant_pairs), \
                                                         sum([x[1] for x in relevant_pairs.values()])))

        cid_f = pair[0]
        c_f = contracts[cid_f]
        sid_reject = pair[1]
        c_nf = programs[c_f.pid].nf
        cid_nf = c_nf.ID

        curr_scores_f = [c_f.scores_dic[sid] for sid in contract_to_match[cid_f]]
        curr_scores_nf = [c_nf.scores_dic[sid] for sid in contract_to_match[cid_nf]]

        # Override score: give the student score 0 at the funded contract
        # and move their ranking entry to the end of the ranking.
        backup_sid_reject_score = c_f.scores_dic[sid_reject]
        # Position (an integer index) of the student's entry in the ranking
        backup_sid_ind = [i for i, entry in enumerate(c_f.ranking) \
                          if entry[1] == sid_reject][0]
        c_f.scores_dic[sid_reject] = 0
        c_f.ranking = c_f.ranking[:backup_sid_ind] + c_f.ranking[backup_sid_ind+1:] + [c_f.ranking[backup_sid_ind]]
        c_f.ranking[-1][0] = 0

        # Re-run algorithm
        new_mu = spda_opt(students, contracts)

        new_scores_f = [c_f.scores_dic[sid] for sid in new_mu if new_mu[sid] == cid_f]
        new_scores_nf = [c_nf.scores_dic[sid] for sid in new_mu if new_mu[sid] == cid_nf]


        # We assume here that scores are the same across contracts!
        curr_scores = sorted(curr_scores_f + curr_scores_nf)
        new_scores = sorted(new_scores_f + new_scores_nf)

        success = True

        if curr_scores == new_scores:
            success = False
            curr_scores = []

        # Success requires pairing every old score with a distinct new
        # score that is at least as high.
        while curr_scores != []:
            a = min(curr_scores)
            lb = [b for b in new_scores if b >= a]
            if len(lb) == 0:
                success = False
                break
            curr_scores.remove(a)
            new_scores.remove(min(lb))

        if success:
            relevant_pairs[pair][1] = 1

        # Return to tie-broken score, and move the entry from the end of
        # the ranking back to its old position (the slice stops before the
        # last element so the entry is not duplicated).
        c_f.scores_dic[sid_reject] = backup_sid_reject_score
        c_f.ranking[-1][0] = backup_sid_reject_score
        c_f.ranking = c_f.ranking[:backup_sid_ind] + [c_f.ranking[-1]] + c_f.ranking[backup_sid_ind:-1]

        # Save to file
        relevant_pairs[pair][0] = 1
        sum_tested = sum([x[0] for x in relevant_pairs.values()])
        if (sum_tested % 20) == 0 or sum_tested == len(relevant_pairs):
            pickle_dump(relevant_pairs, pairs_fname)


    # Return to original scores
    for c in contracts.values():
        for sid in c.scores_dic:
            c.scores_dic[sid] = backup_scores_dic[c.ID][sid]

    return [pair for pair in relevant_pairs if relevant_pairs[pair][1]]


#############################################################
# Get list of blocking pairs
#############################################################

def get_blocking_pairs(students, contracts, programs, mu, check_rejects = []):
    """Get list of blocking pairs.

    Args:
        students: dict of student ID -> student.
        contracts: dict of contract ID -> contract.
        programs: dict of program ID -> program.
        mu: matching, dict of student ID -> contract ID or None.
        check_rejects: student IDs for which the additional
            within-program checks are run (only when the ordinary
            check finds no blocking pair).

    Returns:
        A list of (contract ID, student ID) tuples; empty if none found.
    """
    
    # Create implied cutoffs for all contracts
    contracts_to_students = {}
    for cid in contracts: contracts_to_students[cid] = set()
    for sid in students:
        if mu[sid] != None: contracts_to_students[mu[sid]].add(sid)
        
    # Cutoff: unreachable (10**8) for capacity 0, -1 for a contract with
    # free seats, otherwise the lowest score among its matched students
    cutoffs = {}
    for c in contracts.values():
        if c.capacity == 0: cutoffs[c.ID] = 10**8
        elif len(contracts_to_students[c.ID]) < c.capacity: cutoffs[c.ID] = -1
        else: cutoffs[c.ID] = min([c.scores_dic[sid] for sid in contracts_to_students[c.ID]])
    

    blocking_pairs = []
    
    for s in students.values():
        sid = s.ID
        # Only contracts ranked above the current match can block
        check_len = len(s.ranking)
        ignore_list = []
        if mu[sid] != None:
            check_len = s.ranking.index(mu[sid])
            # Contracts of the program the student is matched to
            ignore_list = [x.ID for x in programs[contracts[mu[sid]].pid].cs]

        for cid in s.ranking[:check_len]:
            c = contracts[cid]
            if c.scores_dic[sid] < UNACCEPTABLE_SCORE: continue
            if cid in ignore_list:
                # Within the student's own program a pair blocks only
                # when the preferred contract has a free seat
                if len(contracts_to_students[cid]) < c.capacity:
                    blocking_pairs.append((cid,sid))
                #continue
            elif c.scores_dic[sid] > cutoffs[cid]:
                blocking_pairs.append((cid,sid))

    # Ordinary blocking pairs take precedence over the extra checks below
    if blocking_pairs != []:
        return blocking_pairs

    if len(check_rejects) > 0:

        # Student wanted funded contract, and there is no reason
        # for her not to be there
        # used_cids: contract ID -> number of pairs already reported for it
        used_cids = {}
        for sid in check_rejects:
            s = students[sid]
            if mu[sid] == None: continue
            check_len = s.ranking.index(mu[sid])
            rel_list = [x.ID for x in programs[contracts[mu[sid]].pid].cs]
            for cid in s.ranking[:check_len]:
                c = contracts[cid]
                if c.scores_dic[sid] < UNACCEPTABLE_SCORE: continue
                if cid in used_cids:
                    # If equal to free quota, stop
                    used_cap = len([x for x in mu if mu[x] == cid])
                    if used_cids[cid] >= c.capacity - used_cap:
                        continue

                if cid in rel_list:
                    if len(contracts_to_students[cid]) < c.capacity:
                           blocking_pairs.append((cid,sid))
                           used_cids[cid] = used_cids.get(cid,0) + 1
        
        # Search for two students matched to the funded and non-funded contract
        # such that they would love to switch places
        for p in programs.values():
            if len(p.cs) != 2: continue
            c_f = p.f
            cid_f = c_f.ID
            c_nf = p.nf
            cid_nf = c_nf.ID

            # At most one report per funded contract
            if cid_f in used_cids:
                continue

            found_flag = False
            
            # s1: matched to the funded contract, ranks the non-funded one
            # higher and is acceptable there
            for sid1 in contracts_to_students[cid_f]:
                s1 = students[sid1]
                if cid_nf not in s1.ranking: continue
                if s1.ranking.index(cid_f) < s1.ranking.index(cid_nf): continue
                if sid1 not in c_nf.scores_dic: continue
                if c_nf.scores_dic[sid1] < UNACCEPTABLE_SCORE: continue

                # s2: a checked student matched to the non-funded contract
                # who ranks the funded one higher and is acceptable there
                for sid2 in contracts_to_students[cid_nf]:
                    if sid2 not in check_rejects: continue
                    s2 = students[sid2]
                    if cid_f not in s2.ranking: continue
                    if s2.ranking.index(cid_nf) < s2.ranking.index(cid_f): continue
                    if sid2 not in c_f.scores_dic: continue
                    if c_f.scores_dic[sid2] < UNACCEPTABLE_SCORE: continue

                    blocking_pairs.append((cid_f, sid2))
                    used_cids[cid_f] = 1
                    found_flag = True
                    break

                if found_flag: break
                    
                
    return blocking_pairs


#############################################################
# Backup the original ranking before running algorithms
#############################################################
# Independent copy of each student's ranking, keyed by student ID, so the
# rankings can be restored after an algorithm has modified them
rankings_backup = {}
for s in students.values():
    rankings_backup[s.ID] = list(s.ranking)

#############################################################
# Algorithm 1 - Try all rejections, and gradually remove
#############################################################

def reject_and_run(students, contracts, rejects, seed, start_ms = None):
    """Reject students from specific contracts, and run SPDA.

    Tie-breaking is applied first (STB with the given seed), then every
    rejected contract is removed from the corresponding student's ranking
    and spda_opt is run.  Before returning, each contract's scores_dic is
    reset from the global backup_scores_dic and the ranking of every
    student named in rejects is reset from the global rankings_backup.
    The reset also happens when spda_opt raises, so a failed run does not
    leave the shared data modified.

    Parameters:
        students  -- dict: student ID -> student object (uses .ranking).
        contracts -- dict: contract ID -> contract object (uses .ID and
                     .scores_dic).
        rejects   -- list of pairs; item 0 is a contract ID, item 1 is a
                     student ID.
        seed      -- seed handed to STB.
        start_ms  -- optional; when not None it is forwarded to spda_opt
                     as (students, contracts, start_ms, False).

    Returns:
        The matching returned by spda_opt.
    """

    STB(students, contracts, backup_scores_dic, seed)

    touched_sids = set()
    try:
        # Go over rejects and remove the contract from the ranking
        for pair in rejects:
            cid = pair[0]
            sid = pair[1]
            ranking = students[sid].ranking
            touched_sids.add(sid)
            if cid in ranking:
                ranking.remove(cid)

        # Run DA
        if start_ms is None:
            new_mu = spda_opt(students, contracts)
        else:
            new_mu = spda_opt(students, contracts, start_ms, False)
    finally:
        # Go back to original scores and rankings
        for c in contracts.values():
            c.scores_dic = deepcopy(backup_scores_dic[c.ID])

        for sid in touched_sids:
            students[sid].ranking = rankings_backup[sid][:]

    # Return the resulting matching
    return new_mu



def algorithm1(students, contracts, programs, MP_pairs, seed, remove_mode = REMOVE_MODE):
    """Alternate algorithm 1.

    Starts by rejecting every (contract ID, student ID) pair in MP_pairs
    (pairs whose contract has .dual set are skipped when
    AVOID_DUAL_PROGRAMS is on), runs reject_and_run, and then, for as long
    as get_blocking_pairs reports blocking pairs, withdraws some of the
    rejections and runs again.

    Parameters:
        students, contracts, programs -- the ID -> object dictionaries
            handed on to reject_and_run / get_blocking_pairs.
        MP_pairs    -- list of (contract ID, student ID) pairs; not modified.
        seed        -- seed handed to reject_and_run.
        remove_mode -- which students have all of their rejections withdrawn:
            REMOVE_ALL: every student appearing in a blocking pair;
            REMOVE_HIGHEST_ONE: for the contract of the smallest blocking
                pair (in sort order), the student with the largest
                (score, student ID);
            REMOVE_HIGHEST_EACH_CONTRACT: the same choice made for every
                contract that appears in a blocking pair.
            If this withdraws nothing, the first ceil(10%) of the
            remaining rejects is dropped instead.

    Returns:
        The last matching returned by reject_and_run.

    Raises:
        Exception -- if remove_mode is none of the three modes above and
            at least one blocking pair was found.
    """

    def select_removals(pairs):
        """Return the set of student IDs picked from pairs per remove_mode."""
        if remove_mode == REMOVE_ALL:
            return {x[1] for x in pairs}
        if remove_mode == REMOVE_HIGHEST_ONE:
            first_cid = min(pairs)[0]
            return {max((contracts[x[0]].scores_dic[x[1]], x[1])
                        for x in pairs if x[0] == first_cid)[1]}
        if remove_mode == REMOVE_HIGHEST_EACH_CONTRACT:
            best = {}
            for x in pairs:
                key = (contracts[x[0]].scores_dic[x[1]], x[1])
                if x[0] not in best or key > best[x[0]]:
                    best[x[0]] = key
            return {key[1] for key in best.values()}
        raise Exception("ERROR: remove_mode \"%s\" is illegal" % (str(remove_mode)))

    rejects = MP_pairs[:]

    if AVOID_DUAL_PROGRAMS:
        rejects = [x for x in rejects if not contracts[x[0]].dual]

    print("Starting algorithm1 with %d rejects" % len(rejects))

    new_mu = reject_and_run(students, contracts, rejects, seed)
    sid_rejects = {x[1] for x in rejects}
    blocking_pairs = get_blocking_pairs(students, contracts, programs, new_mu, sid_rejects)

    while len(blocking_pairs) != 0:
        if not rejects:
            # Nothing is left to withdraw, so another run would only
            # reproduce the same blocking pairs; stop instead of looping.
            print("WARNING: algorithm1 returned non-stable matching!")
            break

        remove_sids = select_removals(blocking_pairs)

        old_len = len(rejects)
        rejects = [x for x in rejects if x[1] not in remove_sids]
        if old_len == len(rejects):
            # No rejection belonged to the selected students: force
            # progress by dropping the first tenth (rounded up).
            to_remove = int(ceil(len(rejects)*0.1))
            rejects = rejects[to_remove:]

        print("%d blocking pairs found, len_before = %d, len_after = %d" % \
              (len(blocking_pairs), old_len, len(rejects)))

        new_mu = reject_and_run(students, contracts, rejects, seed)
        sid_rejects = {x[1] for x in rejects}
        blocking_pairs = get_blocking_pairs(students, contracts, programs, new_mu, sid_rejects)

    return new_mu


#################################################################
# Algorithm 2 - Flip all funded-non-funded consecutive rankings
#################################################################

def algorithm2(students, contracts, programs, seed, remove_mode = REMOVE_MODE):
    """Alternate algorithm 2.

    After tie-breaking (STB), every student's ranking is scanned for two
    adjacent contracts of the same program (.pid) where the upper one has
    funded == 1; such a pair is swapped (skipped for .dual contracts when
    AVOID_DUAL_PROGRAMS is on).  spda_opt is run on the flipped rankings,
    the rankings are restored from rankings_backup, and while
    get_blocking_pairs reports blocking pairs some students are taken out
    of the flipped set and the run is repeated.

    Parameters:
        students, contracts, programs -- the ID -> object dictionaries.
        seed        -- seed handed to STB.
        remove_mode -- which flipped students get un-flipped, chosen among
            the blocking pairs whose student is currently flipped:
            REMOVE_ALL: all of them;
            REMOVE_HIGHEST_ONE: for the contract of the smallest such pair
                (in sort order), the student with the largest
                (score, student ID);
            REMOVE_HIGHEST_EACH_CONTRACT: the same choice for every
                contract in such a pair.
            If this un-flips nobody, ceil(10%) of the flipped students
            (in set iteration order) are un-flipped instead.

    Returns:
        The last matching returned by spda_opt.  A warning is printed if
        get_blocking_pairs still reports blocking pairs for it.

    Raises:
        Exception -- if remove_mode is none of the three modes above and
            the removal step is reached.
    """

    def select_removals(pairs):
        """Return the set of student IDs picked from pairs per remove_mode."""
        if remove_mode == REMOVE_ALL:
            return {x[1] for x in pairs}
        if remove_mode == REMOVE_HIGHEST_ONE:
            if not pairs:
                # No blocking pair involves a flipped student.
                return set()
            first_cid = min(pairs)[0]
            return {max((contracts[x[0]].scores_dic[x[1]], x[1])
                        for x in pairs if x[0] == first_cid)[1]}
        if remove_mode == REMOVE_HIGHEST_EACH_CONTRACT:
            best = {}
            for x in pairs:
                key = (contracts[x[0]].scores_dic[x[1]], x[1])
                if x[0] not in best or key > best[x[0]]:
                    best[x[0]] = key
            return {key[1] for key in best.values()}
        raise Exception("ERROR: remove_mode \"%s\" is illegal" % (str(remove_mode)))

    print("Starting algorithm2")

    STB(students, contracts, backup_scores_dic, seed)

    flipped = set()

    # Flip consecutive contracts if needed
    for sid, s in students.items():
        # Snapshot of the contract objects; it is not updated by the swaps.
        ranking_cs = [contracts[cid] for cid in s.ranking]
        for i in range(len(ranking_cs) - 1):
            if ranking_cs[i].pid == ranking_cs[i+1].pid and \
               ranking_cs[i].funded == 1:
                if AVOID_DUAL_PROGRAMS and ranking_cs[i].dual: continue
                flipped.add(sid)
                s.ranking[i] = ranking_cs[i+1].ID
                s.ranking[i+1] = ranking_cs[i].ID

    flipped_rankings = {}
    for sid in flipped:
        flipped_rankings[sid] = students[sid].ranking[:]

    new_mu = spda_opt(students, contracts)

    for sid in students:
        students[sid].ranking = rankings_backup[sid][:]

    bpairs = get_blocking_pairs(students, contracts, programs, new_mu, flipped)

    while len(bpairs) > 0:
        if not flipped:
            # Nobody is left to un-flip, so another run cannot change the
            # outcome; the stability check below reports the problem.
            break

        len_before = len(flipped)

        rel_bpairs = [x for x in bpairs if x[1] in flipped]
        flipped.difference_update(select_removals(rel_bpairs))

        len_after = len(flipped)
        if len_before == len_after:
            # Nothing was un-flipped: force progress on a tenth (rounded up).
            to_remove = int(ceil(len(flipped)*0.1))
            flipped = flipped - set(list(flipped)[:to_remove])
            len_after = len(flipped)

        print("%d blocking pairs found, len_before = %d, len_after = %d" % \
              (len(bpairs), len_before, len_after))

        for sid in flipped:
            students[sid].ranking = flipped_rankings[sid][:]

        new_mu = spda_opt(students, contracts)

        for sid in students:
            students[sid].ranking = rankings_backup[sid][:]
        bpairs = get_blocking_pairs(students, contracts, programs, new_mu, flipped)


    # Check stability of new_mu
    if get_blocking_pairs(students, contracts, programs, new_mu) != []:
        print("WARNING: algorithm2 returned non-stable matching!")

    return new_mu


#################################################################
# Algorithm 3 - Combine algorithm1 and algorithm2
# This means flipping the consecutive funded/non-funded rankings, and only
# trying rejections for the non-consecutive ones.
#################################################################

def algorithm3(students, contracts, programs, MP_pairs, seed, remove_mode = REMOVE_MODE):
    """Alternate algorithm 3.

    Combines the two other alternate algorithms: adjacent same-program
    contracts whose upper one has funded == 1 are swapped in the students'
    rankings, and the pairs of MP_pairs whose funded contract is ranked
    more than one place above the program's non-funded contract are used
    as rejections for reject_and_run.  While get_blocking_pairs reports
    blocking pairs, some rejections are withdrawn and some students are
    un-flipped, and the run is repeated.

    Parameters:
        students, contracts, programs -- the ID -> object dictionaries.
        MP_pairs    -- list of (funded contract ID, student ID) pairs.
        seed        -- seed handed to reject_and_run.
        remove_mode -- selection rule, applied once to all blocking pairs
            (for the rejections) and once to the blocking pairs whose
            student is flipped (for the flips):
            REMOVE_ALL: every student in the pairs;
            REMOVE_HIGHEST_ONE: for the contract of the smallest pair (in
                sort order), the student with the largest
                (score, student ID);
            REMOVE_HIGHEST_EACH_CONTRACT: the same choice for every
                contract in the pairs.
            If neither set shrinks, one random rejection or one random
            flipped student is removed.

    Returns:
        The last matching returned by reject_and_run.  A warning is
        printed if get_blocking_pairs still reports blocking pairs for it.

    Raises:
        Exception -- if remove_mode is none of the three modes above and
            the removal step is reached.
    """

    def select_removals(pairs):
        """Return the set of student IDs picked from pairs per remove_mode."""
        if remove_mode == REMOVE_ALL:
            return {x[1] for x in pairs}
        if remove_mode == REMOVE_HIGHEST_ONE:
            if not pairs:
                # E.g. no blocking pair involves a flipped student.
                return set()
            first_cid = min(pairs)[0]
            return {max((contracts[x[0]].scores_dic[x[1]], x[1])
                        for x in pairs if x[0] == first_cid)[1]}
        if remove_mode == REMOVE_HIGHEST_EACH_CONTRACT:
            best = {}
            for x in pairs:
                key = (contracts[x[0]].scores_dic[x[1]], x[1])
                if x[0] not in best or key > best[x[0]]:
                    best[x[0]] = key
            return {key[1] for key in best.values()}
        raise Exception("ERROR: remove_mode \"%s\" is illegal" % (str(remove_mode)))

    print("Starting algorithm3")

    # The returned matching is not used; the call is kept because any side
    # effects of running it cannot be ruled out from this function alone.
    reject_and_run(students, contracts, [], seed)

    # Flip consecutive contracts if needed
    flipped = set()
    for sid, s in students.items():
        # Snapshot of the contract objects; it is not updated by the swaps.
        ranking_cs = [contracts[cid] for cid in s.ranking]
        for i in range(len(ranking_cs) - 1):
            if ranking_cs[i].pid == ranking_cs[i+1].pid and \
               ranking_cs[i].funded == 1:
                if AVOID_DUAL_PROGRAMS and ranking_cs[i].dual: continue
                flipped.add(sid)
                s.ranking[i] = ranking_cs[i+1].ID
                s.ranking[i+1] = ranking_cs[i].ID

    flipped_rankings = {}
    for sid in flipped: flipped_rankings[sid] = students[sid].ranking[:]

    # Keep the MP_pairs whose funded contract is still ranked more than one
    # place above the non-funded contract of the same program, i.e. the
    # ones the flipping above did not cover.
    rejects = []
    for pair in MP_pairs:
        cid_f = pair[0]
        sid = pair[1]
        c_f = contracts[cid_f]
        cid_nf = programs[c_f.pid].nf.ID
        s = students[sid]

        if cid_f in s.ranking and s.ranking.index(cid_f) + 1 < s.ranking.index(cid_nf):
            rejects.append(pair)

    if AVOID_DUAL_PROGRAMS:
        rejects = [x for x in rejects if not contracts[x[0]].dual]

    print("Starting algorithm3 with %d rejects" % len(rejects))

    new_mu = reject_and_run(students, contracts, rejects, seed)

    for sid in students:
        students[sid].ranking = rankings_backup[sid][:]

    sid_rejects = {x[1] for x in rejects}.union(flipped)
    bpairs = get_blocking_pairs(students, contracts, programs, new_mu, sid_rejects)

    while len(bpairs) != 0:
        old_len_rejects = len(rejects)
        old_len_flipped = len(flipped)

        # Withdraw rejections based on all blocking pairs ...
        drop_rejects = select_removals(bpairs)
        rejects = [x for x in rejects if x[1] not in drop_rejects]

        # ... and un-flip based on the blocking pairs of flipped students.
        rel_bpairs = [x for x in bpairs if x[1] in flipped]
        flipped.difference_update(select_removals(rel_bpairs))

        if len(rejects) == old_len_rejects and \
           len(flipped) == old_len_flipped:
            if not rejects and not flipped:
                # Nothing is left to undo, so another run cannot change the
                # outcome; the stability check below reports the problem.
                break
            if rejects and (not flipped or random.random() > 0.5):
                # Remove one random element from the rejects
                rejects.remove(random.choice(rejects))
            else:
                # Remove one random element from flipped; random.choice
                # needs a sequence, and sorting keeps the pick reproducible
                # for a given random state.
                flipped.remove(random.choice(sorted(flipped)))

        print("%d blocking pairs found, len(A) = %d, len(A') = %d, len(B) = %d, len(B') = %d" % \
              (len(bpairs), old_len_rejects, len(rejects), \
               old_len_flipped, len(flipped)))


        for sid in flipped:
            students[sid].ranking = flipped_rankings[sid][:]

        new_mu = reject_and_run(students, contracts, rejects, seed)
        for sid in students:
            students[sid].ranking = rankings_backup[sid][:]

        sid_rejects = {x[1] for x in rejects}.union(flipped)
        bpairs = get_blocking_pairs(students, contracts, programs, new_mu, sid_rejects)


    for sid in students:
        students[sid].ranking = rankings_backup[sid][:]

    # Check stability of new_mu
    if get_blocking_pairs(students, contracts, programs, new_mu) != []:
        print("WARNING: algorithm3 returned non-stable matching!")

    return new_mu


##################################
# Running
##################################

def create_all_outputs():
    """Create all outputs relevant to the paper.

    Runs three passes over the module-level students / contracts /
    programs data: the first on the data as it is, the second after
    calling remove_ABC(), the third after additionally calling
    remove_ABC_caps().  Each pass

      * refreshes the global backups backup_scores_dic and backup_ranking,
      * sets every contract's capacity to its fixed_real_capacity and
        writes them to "capacities_<tag>.csv",
      * runs cutoffs_alg_no_tb, tie-breaking (STB with GLOBAL_SEED), srda,
        spda_opt and the three alternate algorithms,
      * prints the number of students with a non-None entry in each
        matching, and
      * writes all five matchings to "stable_matchings_<tag>.csv",

    where <tag> is "abc_remove_00", "abc_remove_10" or "abc_remove_11".
    """
    global backup_scores_dic, backup_ranking

    for step in range(3):
        if step == 0:
            abc_tag = "abc_remove_00"
        elif step == 1:
            remove_ABC()
            abc_tag = "abc_remove_10"
        elif step == 2:
            remove_ABC_caps()
            abc_tag = "abc_remove_11"

        # Create backup of the contracts' scores and rankings; the scores
        # backup is what STB and reject_and_run work from.
        backup_scores_dic = {}
        backup_ranking = {}
        for c in contracts.values():
            backup_scores_dic[c.ID] = deepcopy(c.scores_dic)
            backup_ranking[c.ID] = deepcopy(c.ranking)

        # Set all to fixed_real_capacity, and run basic analysis
        for c in contracts.values():
            c.capacity = c.fixed_real_capacity

        # Print out the capacities used
        with open("capacities_" + abc_tag + ".csv", "w") as f:
            f.write("cid,capacity\n")
            for c in contracts.values():
                f.write("%s,%d\n" % (c.ID, c.capacity))

        # Run without tie-breaking (the returned matching is not used here)
        cutoffs_alg_no_tb(students, contracts)

        # Run tie-breaking
        STB(students, contracts, backup_scores_dic, GLOBAL_SEED)

        mu_srda = srda(students, contracts)
        mu_spda = spda_opt(students, contracts)

        MP_pairs = get_potential_rejects(True, GLOBAL_SEED, \
                                         "fixed_real_capacity_" + \
                                         abc_tag)
        mu1 = algorithm1(students, contracts, programs, MP_pairs, GLOBAL_SEED)
        mu2 = algorithm2(students, contracts, programs, GLOBAL_SEED)
        mu3 = algorithm3(students, contracts, programs, MP_pairs, GLOBAL_SEED)

        # Report how many entries of each matching are not None.
        matchings = [("mu_srda", mu_srda), ("mu_spda", mu_spda),
                     ("mu1", mu1), ("mu2", mu2), ("mu3", mu3)]
        for mu_name, matching in matchings:
            matched = sum(1 for x in matching if matching[x] is not None)
            print("size(%s) = %d" % (mu_name, matched))

        # Save all results to file.
        with open("stable_matchings_" + abc_tag + ".csv", "w") as f:
            f.write("sid,mu_srda,mu_spda,mu1,mu2,mu3\n")
            for sid in sorted(mu_srda.keys()):
                f.write("%s,%s,%s,%s,%s,%s\n" % (sid, str(mu_srda[sid]), \
                                                 str(mu_spda[sid]), \
                                                 str(mu1[sid]), \
                                                 str(mu2[sid]), str(mu3[sid])))


create_all_outputs()
